/*
@file: ast.h
@author: ZZH
@time: 2021-12-27 15:26:01
@info: 抽象语法树实现
*/
#pragma once
#include <string>
#include <list>
#include <cstdint>
#include <cstring>
#include <iostream>
#include "operations.h"
#include "symTable.h"
#include "ErrInfo.h"

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Function.h"

// Global LLVM code-generation state. These are `extern` declarations only; the
// objects are defined in another translation unit.
// NOTE(review): presumably shared by all codeGen() implementations — confirm.
extern llvm::Module* theModule;
extern llvm::LLVMContext theContext;
extern llvm::IRBuilder<> Builder;

// NOTE(review): a using-directive in a header leaks into every file that includes it.
// The declarations below rely on it (string, list, cout, endl), so removing it
// requires std:: qualification throughout this header.
using namespace std;

/* Primitive data types of the source language. */
namespace dataTypes
{
    // Enumerator values double as indices into dataType2Str, so their order and
    // values must stay in sync with that table.
    enum dataType_t
    {
        None   = 0, // no type
        Auto   = 1, // presumably "type to be inferred" — confirm
        Bool   = 2,
        Byte   = 3,
        Int    = 4,
        Float  = 5,
        String = 6,
    };

    // Printable name of each dataType_t value, indexed by the enumerator (defined elsewhere).
    extern const char* dataType2Str[];
} // namespace dataTypes

/* Base class of every syntax tree node. */
class ASTNode_t
{
private:
    // Dumps the node category and the node's address through the RAW logging macro.
    // The cast is needed because %p expects a void pointer.
    void printBasic() const { RAW("[ASTNode]\r\n    nodeType:%s @ %p", this->getNodeTypeStr(), static_cast<const void*>(this)); }
protected:
    // Coarse node category. Values index nodeType2Str.
    typedef enum
    {
        None,        // error / invalid node
        Expression,  // expression node (set by ASTExpression_t)
        VariableDef, // variable definition
        FunctionDef, // function definition
    } nodeType_t;
public:
    ASTNode_t(nodeType_t type = ASTNode_t::None):nodeType(type) {}
    virtual ~ASTNode_t() {}

    nodeType_t nodeType;
    // Printable name of each nodeType_t value (defined elsewhere).
    static const char* nodeType2Str[];
    inline const char* getNodeTypeStr() const { return nodeType2Str[this->nodeType]; }
    // const so the category can also be queried through const pointers/references.
    inline bool isNodeType(nodeType_t type) const { return this->nodeType == type; }

    // Debug dump; subclasses override it to print their own fields.
    virtual void printInfo() const { this->printBasic(); }

    // Emits LLVM IR for this node.
    virtual llvm::Value* codeGen() = 0;
};

/*变量定义节点,此节点用于符号表*/
class ASTVariableDef_t:public ASTNode_t
{
private:

protected:

public:
    //todo 词法分析器里使用了strdup,理论上说这里可以直接使用指针,然后在析构函数里释放,这样可以避免字符串的复制
    string varName;//变量名
    dataTypes::dataType_t varType;
    uint32_t arrLen;//如果是数组，此处不为0
    uint32_t addr;//偏移地址，在链接时确定

    inline const char* getVarTypeStr() const { return dataTypes::dataType2Str[this->varType]; }

    ASTVariableDef_t(const string& name, dataTypes::dataType_t type = dataTypes::Auto, uint32_t arrLen = 0);
    virtual ~ASTVariableDef_t() {}

    bool isSame(const ASTVariableDef_t& other) const;

    inline bool operator== (const ASTVariableDef_t& other) const { return this->isSame(other); }

    virtual void printInfo() const;

    virtual llvm::Value* codeGen() override;
};

/*表达式节点*/
class ASTExpression_t: public ASTNode_t
{
private:

protected:
    enum expType_t
    {
        Eval,
        Number,
        String,
        VariableCall,
        FunctionCall,
        Block
    }expType;
    static const char* expType2Str[];
    inline const char* getExpTypeStr() const { return this->expType2Str[this->expType]; }
    inline bool isExpType(expType_t type) { return this->expType == type; }
public:
    ASTExpression_t(expType_t type):ASTNode_t(ASTNode_t::Expression), expType(type) {}
    virtual ~ASTExpression_t() {}

    virtual void printInfo() const;
};

//区块，也就是花括号括起来的表达式列表
class ASTBlockStatement_t: public ASTExpression_t
{
public:
    using LocalVariableTable_t = SymTable_t<ASTVariableDef_t*>;
    using ExpressionList_t = list<ASTExpression_t*>;
protected:
    LocalVariableTable_t localVarTable;//局部变量表
    ExpressionList_t expList;//区块内的表达式列表
public:
    ASTBlockStatement_t():ASTExpression_t(ASTExpression_t::Block) {}
    ~ASTBlockStatement_t()
    {
        cout << "del blk exp" << endl;
        for (auto exp : this->expList)
            delete exp;
    }

    //插入一个表达式
    bool insert(ASTExpression_t* pExp);

    //插入一个变量定义
    inline bool insert(ASTVariableDef_t* pVarDef) { return this->localVarTable.insert(pVarDef->varName, pVarDef); }

    inline const ExpressionList_t& getExpList() const { return this->expList; }
    inline const LocalVariableTable_t& getLocalVarTable() const { return this->localVarTable; }

    virtual void printInfo() const;

    virtual llvm::Value* codeGen() override;
};

/*函数定义节点,此节点用于符号表*/
class ASTFunctionDef_t:public ASTNode_t
{
public:
    using vArgSymTable_t = ASTBlockStatement_t::LocalVariableTable_t;
protected:
    vArgSymTable_t* pVArgs;//形式参数
    uint32_t addr;//偏移地址
    ASTBlockStatement_t* funcBody;//函数体为一个区块
    dataTypes::dataType_t retType;//返回值类型

public:
    //todo 词法分析器里使用了strdup,理论上说这里可以直接使用指针,然后在析构函数里释放,这样可以避免字符串的复制
    string funcName;//函数名
    ASTFunctionDef_t(const string& funcName, vArgSymTable_t* pVArgs, ASTBlockStatement_t* funcBody, dataTypes::dataType_t retType);
    ~ASTFunctionDef_t()
    {
        cout << "del funcDef" << endl;
        if (this->pVArgs)//如果函数定义时提供了参数
            delete this->pVArgs;

        if (this->funcBody)
            delete this->funcBody;
    }
    //将函数体指针设为空,避免析构时释放,这样一来其他函数定义对象还可以继续使用之前创建的函数体对象
    inline void keepFuncBody() { this->funcBody = nullptr; }

    //函数定义节点也用于声明,因此需要判断
    inline bool isRef() { return this->funcBody == nullptr; }

    //先定义后声明时,检查声明和定义是否吻合
    void check(const ASTFunctionDef_t* other) const;

    //没有声明过,直接定义函数时,直接将声明的节点作为定义的节点,然后把函数体填入即可
    inline void update(ASTBlockStatement_t* newFuncBody) { this->funcBody = newFuncBody; }

    //使用函数定义更新函数表内的声明,此时不仅重新指定了函数体,还需要检查定义和声明是否一致
    inline void updateAndCheck(const ASTFunctionDef_t* other)
    {
        this->check(other);
        this->update(other->funcBody);
    }

    virtual void printInfo() const;

    virtual llvm::Function* codeGen() override;
};

/*变量调用*/
class ASTVariableCall_t:public ASTExpression_t
{
private:

protected:
    string varName;
public:
    ASTVariableCall_t(const string& varName):ASTExpression_t(ASTExpression_t::VariableCall), varName(varName) {}
    ASTVariableCall_t(const ASTVariableDef_t& other):ASTVariableCall_t(other.varName) {}
    ASTVariableCall_t(const ASTVariableDef_t* other):ASTVariableCall_t(other->varName) {}
    ~ASTVariableCall_t() {}

    virtual void printInfo() const;
    virtual llvm::Value* codeGen() override;
};

/*函数调用*/
class ASTFunctionCall_t:public ASTExpression_t
{
private:

protected:
    string funcName;//调用的函数名称
    list<ASTExpression_t*>* pArgs;//调用函数所使用的实际参数
public:
    ASTFunctionCall_t(const string& funcName, list<ASTExpression_t*>* pArgs):ASTExpression_t(ASTExpression_t::FunctionCall), funcName(funcName), pArgs(pArgs) {}
    ~ASTFunctionCall_t()
    {
        cout << "del pArgs" << endl;
        if (this->pArgs)
        {
            for (auto exp : *this->pArgs)
            {
                if (exp->nodeType == ASTNode_t::Expression)
                {
                    delete exp;
                }
            }
            delete this->pArgs;
        }
    }

    virtual void printInfo() const;
    virtual llvm::Value* codeGen() override;
};

/*二元运算符节点*/
class ASTOperatorDouble_t:public ASTExpression_t
{
private:

protected:
    ASTExpression_t* left_exp, * right_exp;//左子式和右子式
    char op;//运算符
public:
    ASTOperatorDouble_t(char op, ASTExpression_t* lexp, ASTExpression_t* rexp);
    virtual ~ASTOperatorDouble_t() { if (left_exp->isNodeType(ASTNode_t::Expression)) delete left_exp; if (right_exp->isNodeType(ASTNode_t::Expression)) delete right_exp; }

    virtual void printInfo() const;
    virtual llvm::Value* codeGen() override;
};

/*整数节点*/
class ASTNumber_t:public ASTExpression_t
{
private:

protected:
    uint32_t value;//整数数值
public:
    ASTNumber_t(uint32_t value):ASTExpression_t(ASTExpression_t::Number), value(value) {}
    virtual ~ASTNumber_t() {}

    virtual void printInfo() const;

    virtual llvm::Value* codeGen() override;
};

/*字符串节点*/
class ASTString_t:public ASTExpression_t
{
private:

protected:
    char* str;//字符串地址
public:
    ASTString_t(const char* __str):ASTExpression_t(ASTExpression_t::String)
    {
        this->str = new char[strlen(__str)];
        strcpy(this->str, __str);
    }
    ~ASTString_t()
    {
        if (this->str)
            delete this->str;
    }
    virtual void printInfo() const;

    virtual llvm::Value* codeGen() override;
};

class ASTReturnExp_t:public ASTExpression_t
{
private:
    ASTExpression_t* rtExp;
protected:

public:
    ASTReturnExp_t(ASTExpression_t* exp):ASTExpression_t(ASTExpression_t::Number), rtExp(exp) {}
    ~ASTReturnExp_t()
    {
        if (nullptr != this->rtExp)
            delete this->rtExp;
    }

    virtual void printInfo() const
    {
        if (nullptr != this->rtExp)
        {
            this->rtExp->printInfo();
        }
        else
        {
            cout << "return a empty exp? it`s impossible!" << endl;
        }
    }

    virtual llvm::Value* codeGen() override;
};
